from dataset.getdataset_DP import GetDataSetDP
from dataset.getdataset_MSB import GetDataSetMSB
from dataset.getdataset_GC import GetDataSetGC
from simulator import Simulator
import torch
import time
from utils import NodeType
import matplotlib.pyplot as plt
import numpy as np
import math
import copy
import threading
import h5py
from valid import rollout
    
import os


# Model, dataset and device selection. The dataset-specific values below are
# derived from |dataset_name|.
model_name = "MAVEN"
model_set = ""  # suffix appended to checkpoint and log file names

dataset_name = "MSB"
device = torch.device("cuda:2")

# Supported dataset names mapped to their (dataset_dir, layer_r) pair.
_DATASET_SETTINGS = {
    "DP": ("/meshgraphnet_data/deepmind_h5/deforming_plate/", 1),
    "MSB": ("/strip_simple/", 4),
    "GC": ("/cavity_grasping_dataset/", 1),
}


def resolve_dataset_settings(name):
    """Return the ``(dataset_dir, layer_r)`` pair configured for ``name``.

    Args:
        name: Dataset identifier, one of ``"DP"``, ``"MSB"`` or ``"GC"``.

    Returns:
        A tuple ``(dataset_dir, layer_r)``.

    Raises:
        ValueError: If ``name`` is not a supported dataset name. Failing here
            avoids a later ``NameError`` on ``dataset_dir`` / ``layer_r``.
    """
    try:
        return _DATASET_SETTINGS[name]
    except KeyError:
        supported = ", ".join(sorted(_DATASET_SETTINGS))
        raise ValueError(
            f"Unknown dataset_name {name!r}; expected one of: {supported}"
        ) from None


dataset_dir, layer_r = resolve_dataset_settings(dataset_name)

# Training schedule. The *_steps values are counted in optimizer steps
# (one step per batch), not in epochs.
start_lr = 1e-4
end_lr = 1e-5
batch_size = 8
max_steps = 1000000
start_steps = 0
save_steps = 100000   # checkpoint interval
check_steps = 100000  # validation rollout interval
print_steps = 100     # training-loss print interval

config = f"config/config_{model_name}_{dataset_name}.yml"

# Build the simulator from the config file, move it to the target device and
# print its total parameter count.
simulator = Simulator(config, device=device).to(device)
tot_params = sum(weight.numel() for weight in simulator.parameters())
print(tot_params)

# g is the per-step multiplicative decay chosen so that
# start_lr * g ** max_steps == end_lr; p is the learning rate at start_steps.
g = 1.0 / (start_lr / end_lr) ** (1 / max_steps)
p = start_lr * g ** start_steps

optimizer = torch.optim.Adam(
    simulator.parameters(),
    lr=p,
    weight_decay=1e-4,
)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=g)
    
def train(model, dataloader, valid_dataset):
    """Train ``model`` until the global step counter reaches ``max_steps``.

    The dataloader is iterated repeatedly. For every batch, each entry is
    moved to the module-level ``device``, the model is called as
    ``model(batch, True)`` and the mean squared error between its two outputs,
    restricted to nodes whose type equals ``NodeType.NORMAL``, is minimised
    with the module-level ``optimizer``. The module-level ``scheduler`` is
    stepped once per batch.

    Every ``print_steps`` steps the current loss is printed, every
    ``save_steps`` steps a checkpoint is written through
    ``model.save_checkpoint`` and every ``check_steps`` steps ``rollout`` is
    run on ``valid_dataset``.

    Args:
        model: Object providing ``train()``, ``save_checkpoint(savedir=...)``
            and a call returning a ``(predicted, target)`` pair.
        dataloader: Re-iterable source of batches. Each batch must expose
            ``keys()`` and item access, hold values with a ``.to(device)``
            method and contain a ``"node_type"`` entry.
        valid_dataset: Passed unchanged to ``rollout``.

    Raises:
        ValueError: If a full pass over ``dataloader`` yields no batches
            (previously this surfaced as a ``ZeroDivisionError``).
    """
    model.train()

    now_steps = start_steps
    budget_exhausted = False
    for epoch in range(max_steps):
        epoch_training_loss = 0.0
        epoch_training_cnt = 0

        for mydata in dataloader:
            # Work on a copy so the batch object produced by the dataloader
            # is not modified when its entries are replaced below.
            data_list = copy.deepcopy(mydata)
            for key in data_list.keys():
                data_list[key] = data_list[key].to(device)

            predicted_velo, target_velo = model(data_list, True)
            node_type = data_list["node_type"]
            mask = (node_type == NodeType.NORMAL)
            # Only rows belonging to NORMAL nodes contribute to the loss.
            errors = ((predicted_velo - target_velo) ** 2)[mask.reshape(-1), :]
            loss = torch.mean(errors)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            now_steps += 1
            # Read the scalar once; it is used for both the running sum and
            # the periodic print.
            loss_value = loss.item()
            epoch_training_loss += loss_value
            epoch_training_cnt += 1
            if now_steps % print_steps == 0:
                print(f'nowsteps {now_steps}, model_name: {model_name}, training_loss = {loss_value}')

            if now_steps % save_steps == 0:
                print("Save model")
                model.save_checkpoint(savedir = f"checkpoint/{dataset_name}-{model_name}{model_set}-{(now_steps):02d}.pth")
                print("Save Successful")

            if now_steps % check_steps == 0:
                print("Check model on valid set")
                valid_pos, valid_stress = rollout(model, valid_dataset, device)
                print(f"Check Successful on step {now_steps}, valid_pos_error: {valid_pos}, valid_stress_error: {valid_stress}")
                # Re-enter training mode after the rollout (presumably the
                # rollout switches the model to evaluation mode).
                model.train()

            # ">=" rather than an exact-multiple test, so the loop also stops
            # when start_steps is already at or beyond max_steps.
            if now_steps >= max_steps:
                budget_exhausted = True
                break

        if epoch_training_cnt == 0:
            raise ValueError("dataloader yielded no batches; nothing to train on")

        print(f"epoch: {epoch}, epoch_avg_training_loss:{epoch_training_loss / epoch_training_cnt}")

        if budget_exhausted:
            break

if __name__ == '__main__':
    import contextlib
    import traceback

    # Loader class for each supported dataset name.
    dataset_classes = {
        "MSB": GetDataSetMSB,
        "DP": GetDataSetDP,
        "GC": GetDataSetGC,
    }
    if dataset_name not in dataset_classes:
        supported = ", ".join(sorted(dataset_classes))
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected one of: {supported}"
        )
    dataset_cls = dataset_classes[dataset_name]
    dataset = dataset_cls(dataset_dir=dataset_dir, split='train', batch_size=batch_size)
    valid_dataset = dataset_cls(dataset_dir=dataset_dir, split='valid', batch_size=1)

    log_path = f"results/log_{dataset_name}_{model_name}{model_set}.log"
    # Create the log directory up front so a missing folder does not abort
    # the run after the datasets have already been loaded.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    # Everything printed during training (stdout and stderr) is appended to
    # the line-buffered log file; the streams are restored and the file is
    # closed when training ends or fails.
    with open(log_path, 'a', encoding="utf8", buffering=1) as log_writer, \
            contextlib.redirect_stdout(log_writer), \
            contextlib.redirect_stderr(log_writer):
        try:
            train(simulator, dataset, valid_dataset)
        except BaseException:
            # Record the traceback in the log while stderr is still
            # redirected, then let the error propagate as before.
            traceback.print_exc()
            raise

